#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author ：hhx
@Date ：2022/5/22 14:56 
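@Description ：UNet test - run the saved UNet over the test set, plot the flattened x5 features and cluster them with KMeans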
"""
import numpy as np
import os
from utils import *
import torch
from torch import nn, optim
from torch.utils import data
from models import AE, AE_withLinear, UNet
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
from sklearn.cluster import KMeans

# Make float64 the default floating-point dtype for newly created tensors and
# module parameters.  torch.set_default_tensor_type() is deprecated since
# PyTorch 2.1; set_default_dtype() is the supported way to get the same
# default dtype.
torch.set_default_dtype(torch.float64)
# Intel OpenMP flag: tolerate more than one OpenMP runtime being loaded into
# the process instead of aborting with "OMP: Error #15".
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Device used for the model and the input batches.  Swap in the commented
# line below to use CUDA when it is available.
device = 'cpu'
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def main():
    """Run the saved UNet over the test set, then plot and cluster its features.

    For every sample the model is run once and its ``x5`` attribute is
    flattened into a 1-D feature vector.  Selected samples are drawn as
    colour-coded line plots, all feature vectors are clustered with KMeans
    (4 clusters) and the resulting labels are printed.  The feature matrix,
    one row per sample, is saved to ``data.npy``.

    Raises:
        SystemExit: if the data loader yields no samples.
    """
    batch = 1
    # Raw string: '\哨' is not a valid escape sequence and triggers a
    # SyntaxWarning on recent Python versions.  The path value is unchanged.
    datasetpath = r'G:\哨兵2号数据'
    test_set = CarTiffDateSet(datasetpath, type='test')
    test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                              batch_size=batch,
                                              shuffle=False)

    model = UNet.UNet().to(device)
    # map_location lets a checkpoint that was saved on a GPU be loaded on
    # whatever device this script is configured to use.
    state_dict = torch.load("SavedModels/UNet.pkl", map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    # Sample indices that get highlighted in the plot.
    cloud_indices = {0, 1, 2, 3, 4, 5, 9, 17, 24, 37, 39, 40}  # drawn in red
    green_indices = {6, 7, 8}                                  # drawn in blue
    tail_start = 66                       # indices >= 66 are drawn in green

    # Collected per sample instead of writing into a fixed 72 x 2048 buffer,
    # so the script works for any dataset size / feature length and never
    # feeds uninitialised rows to KMeans.
    features = []
    with torch.no_grad():  # inference only: do not build an autograd graph
        for index, images in enumerate(tqdm(test_loader)):
            images = images.to(device)
            # The output itself is not used; the forward pass is needed
            # because `model.x5` is read right after it (presumably set by
            # UNet.forward - confirm in models/UNet).
            model(images)
            fea = model.x5.detach().cpu().numpy().flatten()
            features.append(fea)

            # The three groups are disjoint, so at most one colour applies.
            if index in cloud_indices:
                colour = 'red'
            elif index >= tail_start:
                colour = 'green'
            elif index in green_indices:
                colour = 'blue'
            else:
                colour = None
            if colour is not None:
                plt.plot(fea, label=index, c=colour)

    if not features:
        raise SystemExit(
            f"No samples found in {datasetpath!r}; nothing to plot or cluster.")

    plt.legend()
    plt.show()

    # float64 matches the dtype of the np.empty buffer used previously.
    feature_matrix = np.stack(features).astype(np.float64, copy=False)

    # algorithm="auto" was deprecated in scikit-learn 1.1 and removed in 1.3,
    # so the library default is used instead.  n_init is pinned to 10 (the
    # historical default) so results do not depend on the installed version.
    kmeans = KMeans(n_clusters=4, random_state=0, n_init=10)
    kmeans.fit(feature_matrix)
    print(kmeans.labels_)
    np.save('data.npy', feature_matrix)


if __name__ == '__main__':
    main()
